{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "HW9_XAI.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Explainable AI\n",
        "作業說明投影片\n",
        "\n",
        "Homework Introduction: https://docs.google.com/presentation/d/1KSvfraupHDBnBriVqWCRpB0554rV66SOcaqdtFtqIkk/edit?usp=sharing\n",
        "\n",
        "本作業不提供 python script 版本\n",
        "\n",
        "There is no python script version for this homework\n",
        "\n",
        "若有任何問題，歡迎來信至助教信箱： mlta-2022-spring@googlegroups.com\n",
        "\n",
        "If you have any question, TA's mail: mlta-2022-spring@googlegroups.com\n",
        "\n",
        "## Deadline\n",
        "Mandarin/English: 4/29 release, 5/20 due"
      ],
      "metadata": {
        "id": "h79sv7nH1zlE"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# **Homework 9 - Explainable AI (Part 1 CNN)**"
      ],
      "metadata": {
        "id": "GLm4rv5y2sBx"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Connect to google drive"
      ],
      "metadata": {
        "id": "ZUKTWPXW20wk"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GVsbrnFf5nn1"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive', force_remount=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z6mlsay85931"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "os.makedirs('/content/gdrive/My Drive/MLHW_XAI', exist_ok=True)\n",
        "os.chdir('/content/gdrive/My Drive/MLHW_XAI')\n",
        "!ls"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Environment Settings"
      ],
      "metadata": {
        "id": "N8wlNl743b6o"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yFilmmPP3pRT"
      },
      "outputs": [],
      "source": [
        "# download and unzip training data\n",
        "!gdown --id '1QntUQuWJoVR8h5FoeDa56xrQSdcCwFeD' --output food.zip\n",
        "!unzip food.zip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zhzdomRTOKoJ"
      },
      "outputs": [],
      "source": [
        "# download pretrained model\n",
        "!gdown --id '1-Qw-oIJ0cSo2iG_n_U9mcJqXc2-LCSdV' --output checkpoint.pth"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kP4zsXYaI47z"
      },
      "outputs": [],
      "source": [
        "# install lime in colab\n",
        "!pip install lime==0.1.1.37"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "leWC9kKgL55n"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import argparse\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.optim import Adam\n",
        "from torch.utils.data import Dataset\n",
        "import torchvision.transforms as transforms\n",
        "from skimage.segmentation import slic\n",
        "from lime import lime_image\n",
        "from pdb import set_trace\n",
        "from torch.autograd import Variable"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Argument Parsing"
      ],
      "metadata": {
        "id": "GVQyXl9v_c_9"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kbVhzRFp8dnu"
      },
      "outputs": [],
      "source": [
        "args = {\n",
        "      'ckptpath': './checkpoint.pth',\n",
        "      'dataset_dir': './food/'\n",
        "}\n",
        "args = argparse.Namespace(**args)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Model Definition and Checkpoint Loading"
      ],
      "metadata": {
        "id": "HnoTVVRq_toZ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iqXQTwoxeRbO"
      },
      "outputs": [],
      "source": [
        "# Model definition\n",
        "class Classifier(nn.Module):\n",
        "  def __init__(self):\n",
        "    super(Classifier, self).__init__()\n",
        "\n",
        "    def building_block(indim, outdim):\n",
        "      return [\n",
        "        nn.Conv2d(indim, outdim, 3, 1, 1),\n",
        "        nn.BatchNorm2d(outdim),\n",
        "        nn.ReLU(),\n",
        "      ]\n",
        "    def stack_blocks(indim, outdim, block_num):\n",
        "      layers = building_block(indim, outdim)\n",
        "      for i in range(block_num - 1):\n",
        "        layers += building_block(outdim, outdim)\n",
        "      layers.append(nn.MaxPool2d(2, 2, 0))\n",
        "      return layers\n",
        "\n",
        "    cnn_list = []\n",
        "    cnn_list += stack_blocks(3, 128, 3)\n",
        "    cnn_list += stack_blocks(128, 128, 3)\n",
        "    cnn_list += stack_blocks(128, 256, 3)\n",
        "    cnn_list += stack_blocks(256, 512, 1)\n",
        "    cnn_list += stack_blocks(512, 512, 1)\n",
        "    self.cnn = nn.Sequential( * cnn_list)\n",
        "\n",
        "    dnn_list = [\n",
        "      nn.Linear(512 * 4 * 4, 1024),\n",
        "      nn.ReLU(),\n",
        "      nn.Dropout(p = 0.3),\n",
        "      nn.Linear(1024, 11),\n",
        "    ]\n",
        "    self.fc = nn.Sequential( * dnn_list)\n",
        "\n",
        "  def forward(self, x):\n",
        "    out = self.cnn(x)\n",
        "    out = out.reshape(out.size()[0], -1)\n",
        "    return self.fc(out)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "er56l_dOAKBO"
      },
      "outputs": [],
      "source": [
        "# Load trained model\n",
        "model = Classifier().cuda()\n",
        "checkpoint = torch.load(args.ckptpath)\n",
        "model.load_state_dict(checkpoint['model_state_dict'])\n",
        "# It should display: <All keys matched successfully> "
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Dataset Definition and Creation"
      ],
      "metadata": {
        "id": "GFilFnNYAExt"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_iKBcxwa_Rpl"
      },
      "outputs": [],
      "source": [
        "# It might take some time, if it is too long, try to reload it.\n",
        "# Dataset definition\n",
        "class FoodDataset(Dataset):\n",
        "    def __init__(self, paths, labels, mode):\n",
        "        # mode: 'train' or 'eval'\n",
        "        \n",
        "        self.paths = paths\n",
        "        self.labels = labels\n",
        "        trainTransform = transforms.Compose([\n",
        "            transforms.Resize(size=(128, 128)),\n",
        "            transforms.RandomHorizontalFlip(),\n",
        "            transforms.RandomRotation(15),\n",
        "            transforms.ToTensor(),\n",
        "        ])\n",
        "        evalTransform = transforms.Compose([\n",
        "            transforms.Resize(size=(128, 128)),\n",
        "            transforms.ToTensor(),\n",
        "        ])\n",
        "        self.transform = trainTransform if mode == 'train' else evalTransform\n",
        "\n",
        "    # pytorch dataset class\n",
        "    def __len__(self):\n",
        "        return len(self.paths)\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        X = Image.open(self.paths[index])\n",
        "        X = self.transform(X)\n",
        "        Y = self.labels[index]\n",
        "        return X, Y\n",
        "\n",
        "    # help to get images for visualizing\n",
        "    def getbatch(self, indices):\n",
        "        images = []\n",
        "        labels = []\n",
        "        for index in indices:\n",
        "          image, label = self.__getitem__(index)\n",
        "          images.append(image)\n",
        "          labels.append(label)\n",
        "        return torch.stack(images), torch.tensor(labels)\n",
        "\n",
        "# help to get data path and label\n",
        "def get_paths_labels(path):\n",
        "    def my_key(name):\n",
        "      return int(name.replace(\".jpg\",\"\").split(\"_\")[1])+1000000*int(name.split(\"_\")[0])\n",
        "    imgnames = os.listdir(path)\n",
        "    imgnames.sort(key=my_key)\n",
        "    imgpaths = []\n",
        "    labels = []\n",
        "    for name in imgnames:\n",
        "        imgpaths.append(os.path.join(path, name))\n",
        "        labels.append(int(name.split('_')[0]))\n",
        "    return imgpaths, labels\n",
        "train_paths, train_labels = get_paths_labels(args.dataset_dir)\n",
        "\n",
        "train_set = FoodDataset(train_paths, train_labels, mode='eval')"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## The Images for Observation\n",
        "There are 11 categories of food: Bread, Dairy product, Dessert, Egg, Fried food, Meat, Noodles/Pasta, Rice, Seafood, Soup, and Vegetable/Fruit.\n",
        "Be sure that the images shown here are **the same as** those in the slides.\n",
        "The images are marked from 0 to 9."
      ],
      "metadata": {
        "id": "DHj9yvMsAumF"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zKFmM1sacjyG"
      },
      "outputs": [],
      "source": [
        "img_indices = [i for i in range(10)]\n",
        "images, labels = train_set.getbatch(img_indices)\n",
        "fig, axs = plt.subplots(1, len(img_indices), figsize=(15, 8))\n",
        "for i, img in enumerate(images):\n",
        "  axs[i].imshow(img.cpu().permute(1, 2, 0))\n",
        "# print(labels) # this line can help you know the labels of each image"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Lime (Q1~4) \n",
        "[Lime](https://github.com/marcotcr/lime) is a package about explaining what machine learning classifiers are doing. We can first use it to observe the model."
      ],
      "metadata": {
        "id": "nGGVJx9IBe8F"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RI6e9_68HvQe"
      },
      "outputs": [],
      "source": [
        "def predict(input):\n",
        "    # input: numpy array, (batches, height, width, channels)                                                                                                                                                     \n",
        "    \n",
        "    model.eval()                                                                                                                                                             \n",
        "    input = torch.FloatTensor(input).permute(0, 3, 1, 2)                                                                                                            \n",
        "    # pytorch tensor, (batches, channels, height, width)\n",
        "\n",
        "    output = model(input.cuda())                                                                                                                                             \n",
        "    return output.detach().cpu().numpy()                                                                                                                              \n",
        "                                                                                                                                                                             \n",
        "def segmentation(input):\n",
        "    # split the image into 200 pieces with the help of segmentaion from skimage                                                                                                                   \n",
        "    return slic(input, n_segments=200, compactness=1, sigma=1, start_label=1)                                                                                                              \n",
        "                                                                                                                                                                             \n",
        "\n",
        "fig, axs = plt.subplots(1, len(img_indices), figsize=(15, 8))                                                                                                                                                                 \n",
        "# fix the random seed to make it reproducible\n",
        "np.random.seed(16)                                                                                                                                                       \n",
        "for idx, (image, label) in enumerate(zip(images.permute(0, 2, 3, 1).numpy(), labels)):                                                                                                                                             \n",
        "    x = image.astype(np.double)\n",
        "    # numpy array for lime\n",
        "\n",
        "    explainer = lime_image.LimeImageExplainer()                                                                                                                              \n",
        "    explaination = explainer.explain_instance(image=x, classifier_fn=predict, segmentation_fn=segmentation)\n",
        "\n",
        "    # doc: https://lime-ml.readthedocs.io/en/latest/lime.html?highlight=explain_instance#lime.lime_image.LimeImageExplainer.explain_instance\n",
        "\n",
        "    lime_img, mask = explaination.get_image_and_mask(                                                                                                                         \n",
        "                                label=label.item(),                                                                                                                           \n",
        "                                positive_only=False,                                                                                                                         \n",
        "                                hide_rest=False,                                                                                                                             \n",
        "                                num_features=11,                                                                                                                              \n",
        "                                min_weight=0.05                                                                                                                            \n",
        "                            )\n",
        "    # turn the result from explainer to the image\n",
        "    # doc: https://lime-ml.readthedocs.io/en/latest/lime.html?highlight=get_image_and_mask#lime.lime_image.ImageExplanation.get_image_and_mask\n",
        "    \n",
        "    axs[idx].imshow(lime_img)\n",
        "\n",
        "plt.show()\n",
        "plt.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Saliency Map (Q5~9)"
      ],
      "metadata": {
        "id": "zctMN1bdCTVY"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "What is Saliency map ?\n",
        "\n",
        "Saliency: 顯著性\n",
        "\n",
        "The heatmaps that highlight pixels of the input image that contribute the most in the classification task.\n",
        "\n",
        "Ref: https://medium.com/datadriveninvestor/visualizing-neural-networks-using-saliency-maps-in-pytorch-289d8e244ab4"
      ],
      "metadata": {
        "id": "EYDU5l9PCc6B"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We put an image into the model, forward then calculate the loss referring to the label. Therefore, the loss is related to:\n",
        "\n",
        "\n",
        "*   image\n",
        "*   model parameters\n",
        "*   label\n",
        "\n",
        "Generally speaking, we change model parameters to fit \"image\" and \"label\". When backward, we calculate the partial differential value of **loss to model parameters**.\n",
        "\n",
        "Now, we have another look. When we change the image's pixel value, the partial differential value of **loss to image** shows the change in the loss. We can say that it means the importance of the pixel. We can visualize it to demonstrate which part of the image contribute the most to the model's judgment."
      ],
      "metadata": {
        "id": "SazwNhS4ChoO"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "swIZSW7O04O-"
      },
      "outputs": [],
      "source": [
        "def normalize(image):\n",
        "  return (image - image.min()) / (image.max() - image.min())\n",
        "  # return torch.log(image)/torch.log(image.max())\n",
        "\n",
        "def compute_saliency_maps(x, y, model):\n",
        "  model.eval()\n",
        "  x = x.cuda()\n",
        "\n",
        "  # we want the gradient of the input x\n",
        "  x.requires_grad_()\n",
        "  \n",
        "  y_pred = model(x)\n",
        "  loss_func = torch.nn.CrossEntropyLoss()\n",
        "  loss = loss_func(y_pred, y.cuda())\n",
        "  loss.backward()\n",
        "\n",
        "  # saliencies = x.grad.abs().detach().cpu()\n",
        "  saliencies, _ = torch.max(x.grad.data.abs().detach().cpu(),dim=1)\n",
        "\n",
        "  # We need to normalize each image, because their gradients might vary in scale\n",
        "  saliencies = torch.stack([normalize(item) for item in saliencies])\n",
        "  return saliencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S_w8iSe319Ws"
      },
      "outputs": [],
      "source": [
        "# images, labels = train_set.getbatch(img_indices)\n",
        "saliencies = compute_saliency_maps(images, labels, model)\n",
        "\n",
        "# visualize\n",
        "fig, axs = plt.subplots(2, len(img_indices), figsize=(15, 8))\n",
        "for row, target in enumerate([images, saliencies]):\n",
        "  for column, img in enumerate(target):\n",
        "    if row==0:\n",
        "      axs[row][column].imshow(img.permute(1, 2, 0).numpy())\n",
        "      # What is permute?\n",
        "      # In pytorch, the meaning of each dimension of image tensor is (channels, height, width)\n",
        "      # In matplotlib, the meaning of each dimension of image tensor is (height, width, channels)\n",
        "      # permute is a tool for permuting dimensions of tensors\n",
        "      # For example, img.permute(1, 2, 0) means that,\n",
        "      # - 0 dimension is the 1 dimension of the original tensor, which is height\n",
        "      # - 1 dimension is the 2 dimension of the original tensor, which is width\n",
        "      # - 2 dimension is the 0 dimension of the original tensor, which is channels\n",
        "    else:\n",
        "      axs[row][column].imshow(img.numpy(), cmap=plt.cm.hot)\n",
        "    \n",
        "plt.show()\n",
        "plt.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Smooth Grad (Q10~13)\n",
        "Smooth grad 的方法是，在圖片中隨機地加入 noise，然後得到不同的 heatmap，把這些 heatmap 平均起來就得到一個比較能抵抗 noisy gradient 的結果。\n",
        "\n",
        "The method of Smooth grad is to randomly add noise to the image and get different heatmaps. The average of the heatmaps would be more robust to noisy gradient.\n",
        "\n",
        "ref: https://arxiv.org/pdf/1706.03825.pdf"
      ],
      "metadata": {
        "id": "IAslTKyuDIvq"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EjusfKjISm-k"
      },
      "outputs": [],
      "source": [
        "# Smooth grad\n",
        "\n",
        "def normalize(image):\n",
        "  return (image - image.min()) / (image.max() - image.min())\n",
        "\n",
        "def smooth_grad(x, y, model, epoch, param_sigma_multiplier):\n",
        "  model.eval()\n",
        "  #x = x.cuda().unsqueeze(0)\n",
        "\n",
        "  mean = 0\n",
        "  sigma = param_sigma_multiplier / (torch.max(x) - torch.min(x)).item()\n",
        "  smooth = np.zeros(x.cuda().unsqueeze(0).size())\n",
        "  for i in range(epoch):\n",
        "    # call Variable to generate random noise\n",
        "    noise = Variable(x.data.new(x.size()).normal_(mean, sigma**2))\n",
        "    x_mod = (x+noise).unsqueeze(0).cuda()\n",
        "    x_mod.requires_grad_()\n",
        "\n",
        "    y_pred = model(x_mod)\n",
        "    loss_func = torch.nn.CrossEntropyLoss()\n",
        "    loss = loss_func(y_pred, y.cuda().unsqueeze(0))\n",
        "    loss.backward()\n",
        "\n",
        "    # like the method in saliency map\n",
        "    smooth += x_mod.grad.abs().detach().cpu().data.numpy()\n",
        "  smooth = normalize(smooth / epoch) # don't forget to normalize\n",
        "  # smooth = smooth / epoch # try this line to answer the question\n",
        "  return smooth\n",
        "\n",
        "# images, labels = train_set.getbatch(img_indices)\n",
        "smooth = []\n",
        "for i, l in zip(images, labels):\n",
        "  smooth.append(smooth_grad(i, l, model, 500, 0.4))\n",
        "smooth = np.stack(smooth)\n",
        "# print(smooth.shape)\n",
        "\n",
        "fig, axs = plt.subplots(2, len(img_indices), figsize=(15, 8))\n",
        "for row, target in enumerate([images, smooth]):\n",
        "  for column, img in enumerate(target):\n",
        "    axs[row][column].imshow(np.transpose(img.reshape(3,128,128), (1,2,0)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Z9o2c3jTlwG"
      },
      "source": [
        "## Filter Explanation (Q14~17)\n",
        "\n",
        "這裡我們想要知道某一個 filter 到底認出了什麼。我們會做以下兩件事情：\n",
        "- Filter activation: 挑幾張圖片出來，看看圖片中哪些位置會 activate 該 filter\n",
        "- Filter visualization: 找出怎樣的 image 可以最大程度的 activate 該 filter\n",
        "\n",
        "In this part, we want to know what a specific filter recognize, we'll do\n",
        "- Filter activation: pick up some images, and check which part of the image activates the filter\n",
        "- Filter visualization: look for which kind of image can activate the filter the most\n",
        "\n",
        "實作上比較困難的地方是，通常我們是直接把 image 丟進 model，一路 forward 到底。如：\n",
        "\n",
        "The problem is that, in normal case, we'll directly feed the image to the model, for example,\n",
        "```\n",
        "loss = model(image)\n",
        "loss.backward()\n",
        "```\n",
        "我們要怎麼得到中間某層 CNN 的 output? \n",
        "\n",
        "當然我們可以直接修改 model definition，讓 forward 不只 return loss，也 return activation map。但這樣的寫法麻煩了，更改了 forward 的 output 可能會讓其他部分的 code 要跟著改動。因此 pytorch 提供了方便的 solution: **hook**，以下我們會再介紹。\n",
        "\n",
        "How can we get the output of a specific layer of CNN? \n",
        "\n",
        "We can modify the model definition, make the forward function not only return loss but also retrun the activation map. But this is difficult to maintain the code. As a result, pytorch offers a better solution: **hook**\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BRqCZNd7-Hjb"
      },
      "outputs": [],
      "source": [
        "model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sVJlyj01b_A4"
      },
      "outputs": [],
      "source": [
        "def normalize(image):\n",
        "  return (image - image.min()) / (image.max() - image.min())\n",
        "\n",
        "layer_activations = None\n",
        "def filter_explanation(x, model, cnnid, filterid, iteration=100, lr=1):\n",
        "  # x: input image\n",
        "  # cnnid: cnn layer id\n",
        "  # filterid: which filter\n",
        "  model.eval()\n",
        "\n",
        "  def hook(model, input, output):\n",
        "    global layer_activations\n",
        "    layer_activations = output\n",
        "  \n",
        "  hook_handle = model.cnn[cnnid].register_forward_hook(hook)\n",
        "  # When the model forwards through the layer[cnnid], it needs to call the hook function first\n",
        "  # The hook function save the output of the layer[cnnid]\n",
        "  # After forwarding, we'll have the loss and the layer activation\n",
        "\n",
        "  # Filter activation: x passing the filter will generate the activation map\n",
        "  model(x.cuda()) # forward\n",
        "\n",
        "  # Based on the filterid given by the function argument, pick up the specific filter's activation map\n",
        "  # We just need to plot it, so we can detach from graph and save as cpu tensor\n",
        "  filter_activations = layer_activations[:, filterid, :, :].detach().cpu()\n",
        "  \n",
        "  # Filter visualization: find the image that can activate the filter the most\n",
        "  x = x.cuda()\n",
        "  x.requires_grad_()\n",
        "  # input image gradient\n",
        "  optimizer = Adam([x], lr=lr)\n",
        "  # Use optimizer to modify the input image to amplify filter activation\n",
        "  for iter in range(iteration):\n",
        "    optimizer.zero_grad()\n",
        "    model(x)\n",
        "    \n",
        "    objective = -layer_activations[:, filterid, :, :].sum()\n",
        "    # We want to maximize the filter activation's summation\n",
        "    # So we add a negative sign\n",
        "    \n",
        "    objective.backward()\n",
        "    # Calculate the partial differential value of filter activation to input image\n",
        "    optimizer.step()\n",
        "    # Modify input image to maximize filter activation\n",
        "  filter_visualizations = x.detach().cpu().squeeze()\n",
        "\n",
        "  # Don't forget to remove the hook\n",
        "  hook_handle.remove()\n",
        "  # The hook will exist after the model register it, so you have to remove it after used\n",
        "  # Just register a new hook if you want to use it\n",
        "\n",
        "  return filter_activations, filter_visualizations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l7Q-0mxV-xoo"
      },
      "outputs": [],
      "source": [
        "images, labels = train_set.getbatch(img_indices)\n",
        "filter_activations, filter_visualizations = filter_explanation(images, model, cnnid=6, filterid=0, iteration=100, lr=0.1)\n",
        "\n",
        "fig, axs = plt.subplots(3, len(img_indices), figsize=(15, 8))\n",
        "for i, img in enumerate(images):\n",
        "  axs[0][i].imshow(img.permute(1, 2, 0))\n",
        "# Plot filter activations\n",
        "for i, img in enumerate(filter_activations):\n",
        "  axs[1][i].imshow(normalize(img))\n",
        "# Plot filter visualization\n",
        "for i, img in enumerate(filter_visualizations):\n",
        "  axs[2][i].imshow(normalize(img.permute(1, 2, 0)))\n",
        "plt.show()\n",
        "plt.close()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FAaUtuvl7Chg"
      },
      "outputs": [],
      "source": [
        "images, labels = train_set.getbatch(img_indices)\n",
        "filter_activations, filter_visualizations = filter_explanation(images, model, cnnid=23, filterid=0, iteration=100, lr=0.1)\n",
        "\n",
        "# Plot filter activations\n",
        "fig, axs = plt.subplots(3, len(img_indices), figsize=(15, 8))\n",
        "for i, img in enumerate(images):\n",
        "  axs[0][i].imshow(img.permute(1, 2, 0))\n",
        "for i, img in enumerate(filter_activations):\n",
        "  axs[1][i].imshow(normalize(img))\n",
        "for i, img in enumerate(filter_visualizations):\n",
        "  axs[2][i].imshow(normalize(img.permute(1, 2, 0)))\n",
        "plt.show()\n",
        "plt.close()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "##Integrated Gradients (Q18~20)"
      ],
      "metadata": {
        "id": "VgytLPR5Gw1c"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SK3JVklGiqw6"
      },
      "outputs": [],
      "source": [
        "class IntegratedGradients():\n",
        "    def __init__(self, model):\n",
        "        self.model = model\n",
        "        self.gradients = None\n",
        "        # Put model in evaluation mode\n",
        "        self.model.eval()\n",
        "\n",
        "    def generate_images_on_linear_path(self, input_image, steps):\n",
        "        # Generate scaled xbar images\n",
        "        xbar_list = [input_image*step/steps for step in range(steps)]\n",
        "        return xbar_list\n",
        "\n",
        "    def generate_gradients(self, input_image, target_class):\n",
        "        # We want to get the gradients of the input image\n",
        "        input_image.requires_grad=True\n",
        "        # Forward\n",
        "        model_output = self.model(input_image)\n",
        "        # Zero grads\n",
        "        self.model.zero_grad()\n",
        "        # Target for backprop\n",
        "        one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_().cuda()\n",
        "        one_hot_output[0][target_class] = 1\n",
        "        # Backward\n",
        "        model_output.backward(gradient=one_hot_output)\n",
        "        self.gradients = input_image.grad\n",
        "        # Convert Pytorch variable to numpy array\n",
        "        # [0] to get rid of the first channel (1,3,128,128)\n",
        "        gradients_as_arr = self.gradients.data.cpu().numpy()[0]\n",
        "        return gradients_as_arr\n",
        "\n",
        "    def generate_integrated_gradients(self, input_image, target_class, steps):\n",
        "        # Generate xbar images\n",
        "        xbar_list = self.generate_images_on_linear_path(input_image, steps)\n",
        "        # Initialize an image composed of zeros\n",
        "        integrated_grads = np.zeros(input_image.size())\n",
        "        for xbar_image in xbar_list:\n",
        "            # Generate gradients from xbar images\n",
        "            single_integrated_grad = self.generate_gradients(xbar_image, target_class)\n",
        "            # Add rescaled grads from xbar images\n",
        "            integrated_grads = integrated_grads + single_integrated_grad/steps\n",
        "        # [0] to get rid of the first channel (1,3,128,128)\n",
        "        return integrated_grads[0]\n",
        "\n",
        "def normalize(image):\n",
        "  return (image - image.min()) / (image.max() - image.min())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YiVRoNgFwTY1"
      },
      "outputs": [],
      "source": [
        "# Fetch the selected training images and move them to the GPU\n",
        "images, labels = train_set.getbatch(img_indices)\n",
        "images = images.cuda()\n",
        "\n",
        "# Integrated gradients of each image w.r.t. its own label, with steps=10\n",
        "IG = IntegratedGradients(model)\n",
        "integrated_grads = [\n",
        "  IG.generate_integrated_gradients(image.unsqueeze(0), labels[idx], 10)\n",
        "  for idx, image in enumerate(images)\n",
        "]\n",
        "\n",
        "# Top row: input images; bottom row: normalized integrated gradients\n",
        "fig, axs = plt.subplots(2, len(img_indices), figsize=(15, 8))\n",
        "for col, (image, grads) in enumerate(zip(images, integrated_grads)):\n",
        "  axs[0][col].imshow(image.cpu().permute(1, 2, 0))\n",
        "  axs[1][col].imshow(np.moveaxis(normalize(grads), 0, -1))\n",
        "plt.show()\n",
        "plt.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Congratulations on finishing Part 1 of Homework 9! Almost done!"
      ],
      "metadata": {
        "id": "9pqTWud-RBsP"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# **Homework 9 - Explainable AI (Part 2 BERT)**"
      ],
      "metadata": {
        "id": "E8aO_xSiHEqg"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Attention Visualization (Q21~24)\n",
        "We highly recommend running the visualization directly on this website: https://exbert.net/exBERT.html"
      ],
      "metadata": {
        "id": "6DGbpc88NalA"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dmo6q27Eo-49"
      },
      "outputs": [],
      "source": [
        "from IPython import display\n",
        "\n",
        "# Embed the exBERT page in the cell output (the IFrame is the cell's last expression)\n",
        "EXBERT_URL = \"https://exbert.net/exBERT.html\"\n",
        "display.IFrame(src=EXBERT_URL, width=1600, height=1600)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xzxohy5PYmbx"
      },
      "source": [
        "## Import Packages (For Questions 25 - 30)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4zk9AhBRa5DJ"
      },
      "outputs": [],
      "source": [
        "# Install transformers\n",
        "!pip install transformers==4.5.0\n",
        "\n",
        "# Import all packages needed\n",
        "import numpy as np\n",
        "import random\n",
        "import torch\n",
        "\n",
        "from sklearn.decomposition import PCA\n",
        "from sklearn.metrics import pairwise_distances\n",
        "from transformers import BertModel, BertTokenizerFast\n",
        "\n",
        "##### For Displaying Traditional Chinese in Colab when Drawing with Matplotlib #####\n",
        "# Display Traditional Chinese text when plotting with matplotlib in Colab\n",
        "# Download the Taipei Sans TC font and save it as taipei_sans_tc_beta.ttf\n",
        "!gdown --id '1JWHUSlcPwoEzmr0VE6J71jcnwinH10G6' --output taipei_sans_tc_beta.ttf\n",
        "\n",
        "from matplotlib.font_manager import FontProperties\n",
        "import matplotlib.pyplot as plt \n",
        "\n",
        "# Custom font-properties object built from the downloaded font file\n",
        "myfont = FontProperties(fname=r'taipei_sans_tc_beta.ttf')\n",
        "\n",
        "# !!!! In later plotting calls, just pass fontproperties=myfont !!!!\n",
        "##### Code from https://colab.research.google.com/github/willismax/matplotlib_show_chinese_in_colab/blob/master/matplotlib_show_chinese_in_colab.ipynb #####\n",
        "\n",
        "plt.rcParams['figure.figsize'] = [12, 10]\n",
        "\n",
        "# Fix random seed for reproducibility\n",
        "def same_seeds(seed):\n",
        "\ttorch.manual_seed(seed)\n",
        "\tif torch.cuda.is_available():\n",
        "\t\ttorch.cuda.manual_seed(seed)\n",
        "\t\ttorch.cuda.manual_seed_all(seed)\n",
        "\tnp.random.seed(seed)\n",
        "\trandom.seed(seed)\n",
        "\ttorch.backends.cudnn.benchmark = False\n",
        "\ttorch.backends.cudnn.deterministic = True\n",
        "\n",
        "same_seeds(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7HTp9kBsxouS"
      },
      "source": [
        "## Embedding Visualization (Q25~27)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### 我們現在有一個預訓練好的模型，並且在閱讀理解上微調過。 \n",
        "閱讀理解需要四個步驟：（這些步驟並**不**按照順序排列）\n",
        "\n",
        "1. 將類似的文字分羣 (根據文字在文章中的關係)\n",
        "\n",
        "2. 提取答案\n",
        "\n",
        "3. 將類似的文字分羣 (根據文字的意思)\n",
        "\n",
        "4. 從文章中尋找與問題有關的資訊\n",
        "\n",
        "#### 你可以在只看見模型 hidden states embedding 的情況下，找出各個layer的功能嗎?\n",
        "\n",
        "<br> \n",
        "\n",
        "#### We have a pre-trained model which is fine-tuned for QA. \n",
        "Solving QA requires 4 steps (the steps below are **NOT** listed in order):\n",
        "\n",
        "1. Clustering similar words together (based on relation of words in context)\n",
        "\n",
        "2. Answer extraction\n",
        "\n",
        "3. Clustering similar words together (based on meaning of words)\n",
        "\n",
        "4. Matching questions with relevant information in context\n",
        "\n",
        "\n",
        "#### Can you find out the functionalities of each layer just by looking into the embedding of hidden states?"
      ],
      "metadata": {
        "id": "l7KtyciK1KW9"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2E6Np-zFstOj"
      },
      "source": [
        "### Download Tokenizers and Models' hidden states"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dgEkI4N7jHxX"
      },
      "outputs": [],
      "source": [
        "# Download the archive with the tokenizer and the saved hidden states\n",
        "!gdown --id '1h3akaNdouiIGItOqEs6kUZE-hAF0QeDk' --output hw9_bert.zip\n",
        "# -o: overwrite existing files without prompting, so re-running this cell does not stop at an interactive prompt\n",
        "!unzip -o hw9_bert.zip"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "syC4-t0Yuqp3"
      },
      "source": [
        "### Load Tokenizers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "joSq-69fkV3x"
      },
      "outputs": [],
      "source": [
        "# Load the fast BERT tokenizer from the local `hw9_bert/Tokenizer` directory\n",
        "# (presumably extracted from hw9_bert.zip by the download cell -- TODO confirm)\n",
        "Tokenizer = BertTokenizerFast.from_pretrained(\"hw9_bert/Tokenizer\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qBa0jYnNriQB"
      },
      "source": [
        "### What to Visualize?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eOWh2Ur7nOsM"
      },
      "outputs": [],
      "source": [
        "# Parallel lists: entry k of each list belongs to question k+1 (the visualization cell indexes them with QUESTION-1).\n",
        "# NOTE: the context strings below use backslash line continuations inside the string literal,\n",
        "# so the leading spaces of each continued line become part of the string.\n",
        "contexts, questions, answers = [], [], []\n",
        "\n",
        "# Question 1\n",
        "contexts += [\"Nikola Tesla (Serbian Cyrillic: Никола Тесла; 10 July 1856 – 7 January 1943) was a Serbian American inventor, electrical engineer, \\\n",
        "            mechanical engineer, physicist, and futurist best known for his contributions to the design of the modern alternating current \\\n",
        "            (AC) electricity supply system.\"]\n",
        "questions += [\"In what year was Nikola Tesla born?\"]\n",
        "answers += [\"1856\"]\n",
        "\n",
        "# Question 2\n",
        "contexts += ['Currently detention is one of the most common punishments in schools in the United States, the UK, Ireland, Singapore and other countries. \\\n",
        "            It requires the pupil to remain in school at a given time in the school day (such as lunch, recess or after school); or even to attend \\\n",
        "            school on a non-school day, e.g. \"Saturday detention\" held at some schools. During detention, students normally have to sit in a classroom \\\n",
        "            and do work, write lines or a punishment essay, or sit quietly.']\n",
        "questions += ['What is a common punishment in the UK and Ireland?']\n",
        "answers += ['detention']\n",
        "\n",
        "# Question 3\n",
        "contexts += ['Wolves are afraid of cats. Sheep are afraid of wolves. Mice are afraid of sheep. Gertrude is a mouse. Jessica is a mouse. \\\n",
        "            Emily is a wolf. Cats are afraid of sheep. Winona is a wolf.']\n",
        "questions += ['What is Emily afraid of?']\n",
        "answers += ['cats']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "keDPUw6onTUz"
      },
      "source": [
        "### TODO\n",
        "This is the only part you need to modify to answer Q25~27."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q1mF0ieFrfDD"
      },
      "outputs": [],
      "source": [
        "# Choose from 1, 2, 3\n",
        "# 1-based: the visualization cell reads index QUESTION-1 of contexts/questions/answers\n",
        "# and loads the hidden states from hw9_bert/output/model_q{QUESTION}\n",
        "QUESTION = 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rJC_zL5a1A0v"
      },
      "source": [
        "### Visualization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wc4LAmuQarE7"
      },
      "outputs": [],
      "source": [
        "# Tokenize and encode question and paragraph into model's input format\n",
        "inputs = Tokenizer(questions[QUESTION-1], contexts[QUESTION-1], return_tensors='pt')\n",
        "input_ids = inputs['input_ids'][0].tolist()\n",
        "\n",
        "# Get the [start, end] positions of [question, context] in encoded sequence for plotting.\n",
        "# The first [SEP] token ends the question; its id is read from the tokenizer instead of hard-coding 102.\n",
        "question_start, question_end = 1, input_ids.index(Tokenizer.sep_token_id) - 1\n",
        "context_start, context_end = question_end + 2, len(input_ids) - 2\n",
        "\n",
        "# Decode each token id back to its word once (the tokens are the same for every layer)\n",
        "words = [Tokenizer.decode(token_id) for token_id in input_ids]\n",
        "# Words of the expected answer, used to highlight matching tokens\n",
        "answer_words = answers[QUESTION-1].split()\n",
        "\n",
        "# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.\n",
        "# map_location='cpu' lets the load succeed on a CPU-only runtime even if the tensors were saved from a GPU.\n",
        "outputs_hidden_states = torch.load(f\"hw9_bert/output/model_q{QUESTION}\", map_location='cpu')\n",
        "\n",
        "##### Traverse hidden state of all layers #####\n",
        "# \"outputs_hidden_state\" is a tuple with 13 elements, the 1st element is embedding output, the other 12 elements are attention hidden states of layer 1 - 12\n",
        "for layer_index, embeddings in enumerate(outputs_hidden_states[1:]): # 1st element is skipped\n",
        "\n",
        "    # \"embeddings\" has shape [1, sequence_length, 768], where 768 is the dimension of BERT's hidden state\n",
        "    # Dimension of \"embeddings\" is reduced from 768 to 2 using PCA (Principal Component Analysis)\n",
        "    reduced_embeddings = PCA(n_components=2, random_state=0).fit_transform(embeddings[0])\n",
        "\n",
        "    ##### Draw embedding of each token #####\n",
        "    for i, word in enumerate(words):\n",
        "        x, y = reduced_embeddings[i] # Embedding has 2 dimensions, each corresponds to a point\n",
        "        # Scatter points of answer, question and context in different colors\n",
        "        if word in answer_words: # Check if word in answer\n",
        "            plt.scatter(x, y, color='blue', marker='d')\n",
        "        elif question_start <= i <= question_end:\n",
        "            plt.scatter(x, y, color='red')\n",
        "        elif context_start <= i <= context_end:\n",
        "            plt.scatter(x, y, color='green')\n",
        "        else: # skip special tokens [CLS], [SEP]\n",
        "            continue\n",
        "        plt.text(x + 0.1, y + 0.2, word, fontsize=12) # Plot word next to its point\n",
        "\n",
        "    # Plot \"empty\" points to show labels\n",
        "    plt.plot([], label='answer', color='blue', marker='d')\n",
        "    plt.plot([], label='question', color='red', marker='o')\n",
        "    plt.plot([], label='context', color='green', marker='o')\n",
        "    plt.legend(loc='best') # Display the area describing the elements in the plot\n",
        "    plt.title('Layer ' + str(layer_index + 1)) # Add title to the plot\n",
        "    plt.show() # Show the plot"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Embedding Analysis (Q28~30)"
      ],
      "metadata": {
        "id": "300HEq1VP4rN"
      }
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y1H-xMOn-dLv"
      },
      "source": [
        "# Load the pre-trained 'bert-base-chinese' model in eval mode, configured to also output its hidden states.\n",
        "# NOTE(review): this rebinds the global `model` that the Part 1 cells also use (e.g. `IntegratedGradients(model)`),\n",
        "# so a Part 1 cell re-run after this point would receive the BERT model instead.\n",
        "model = BertModel.from_pretrained('bert-base-chinese', output_hidden_states=True).eval()\n",
        "# Fast tokenizer for the same checkpoint\n",
        "tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7MV_qNsmHqfD"
      },
      "source": [
        "### What to Visualize?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-WHhO1J3Hx6R"
      },
      "outputs": [],
      "source": [
        "# Sentences for visualization\n",
        "sentences = []\n",
        "sentences += [\"今天買了蘋果來吃\"]\n",
        "sentences += [\"進口蘋果（富士)平均每公斤下跌12.3%\"]\n",
        "sentences += [\"蘋果茶真難喝\"]\n",
        "sentences += [\"老饕都知道智利的蘋果季節即將到來\"]\n",
        "sentences += [\"進口蘋果因防止水分流失故添加人工果糖\"]\n",
        "sentences += [\"蘋果即將於下月發振新款iPhone\"]\n",
        "sentences += [\"蘋果獲新Face ID專利\"]\n",
        "sentences += [\"今天買了蘋果手機\"]\n",
        "sentences += [\"蘋果的股價又跌了\"]\n",
        "sentences += [\"蘋果押寶指紋辨識技術\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UV3pJo85slwz"
      },
      "source": [
        "### TODO\n",
        "\n",
        "This is the only part you need to modify to answer Q28~30."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4tdX9HClgz8p"
      },
      "outputs": [],
      "source": [
        "# Index of the word selected in each sentence for embedding comparison (one entry per sentence).\n",
        "# E.g. for the sentence \"蘋果茶真難喝\", index 0 selects \"蘋\".\n",
        "# The first (active) line holds the indexes of 蘋; the second (commented-out) line holds the indexes of 果.\n",
        "select_word_index = [4, 2, 0, 8, 2, 0, 0, 4, 0, 0]\n",
        "# select_word_index = [5, 3, 1, 9, 3, 1, 1, 5, 1, 1]\n",
        "\n",
        "def euclidean_distance(a, b):\n",
        "    # Compute euclidean distance (L2 norm) between two numpy vectors a and b\n",
        "    return 0\n",
        "\n",
        "def cosine_similarity(a, b):\n",
        "    # Compute cosine similarity between two numpy vectors a and b\n",
        "    return 0\n",
        "\n",
        "# Metric for comparison. Choose from euclidean_distance, cosine_similarity\n",
        "METRIC = euclidean_distance\n",
        "\n",
        "def get_select_embedding(output, tokenized_sentence, select_word_index):\n",
        "    # The layer to visualize, choose from 0 to 12\n",
        "    LAYER = 12\n",
        "    # Get selected layer's hidden state\n",
        "    hidden_state = output.hidden_states[LAYER][0]\n",
        "    # Convert select_word_index in sentence to select_token_index in tokenized sentence\n",
        "    select_token_index = tokenized_sentence.word_to_tokens(select_word_index).start\n",
        "    # Return embedding of selected word\n",
        "    return hidden_state[select_token_index].numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lZFAVj7p-71l"
      },
      "source": [
        "### Visualization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VbM2xA-e-622"
      },
      "outputs": [],
      "source": [
        "# Tokenize and encode sentences into model's input format\n",
        "tokenized_sentences = [tokenizer(sentence, return_tensors='pt') for sentence in sentences]\n",
        "\n",
        "# Input encoded sentences into model and get outputs (no gradients are needed for inference)\n",
        "with torch.no_grad():\n",
        "    outputs = [model(**tokenized_sentence) for tokenized_sentence in tokenized_sentences]\n",
        "\n",
        "# Get embedding of selected word(s) in sentences. \"embeddings\" has shape (len(sentences), 768), where 768 is the dimension of BERT's hidden state\n",
        "embeddings = [get_select_embedding(outputs[i], tokenized_sentences[i], select_word_index[i]) for i in range(len(outputs))]\n",
        "\n",
        "# Pairwise comparison of sentences' embeddings using the metric defined. \"similarity_matrix\" has shape [len(sentences), len(sentences)]\n",
        "similarity_matrix = pairwise_distances(embeddings, metric=METRIC)\n",
        "\n",
        "##### Plot the similarity matrix #####\n",
        "plt.rcParams['figure.figsize'] = [12, 10] # Change figure size of the plot\n",
        "plt.imshow(similarity_matrix) # Display an image in the plot\n",
        "plt.colorbar() # Add colorbar to the plot\n",
        "plt.yticks(ticks=range(len(sentences)), labels=sentences, fontproperties=myfont) # Set tick locations and labels (sentences) of y-axis\n",
        "plt.title('Comparison of BERT Word Embeddings') # Add title to the plot\n",
        "for (row, col), value in np.ndenumerate(similarity_matrix): # np.ndenumerate is the N-dimensional version of enumerate\n",
        "    # plt.text takes (x, y): x is the column and y is the row of the cell drawn by imshow\n",
        "    plt.text(col, row, '{:.2f}'.format(value), ha='center', va='center') # Add values in similarity_matrix to the corresponding position in the plot\n",
        "plt.show() # Show the plot"
      ]
    }
  ]
}